/*
 * Copyright 2010-2016 JetBrains s.r.o.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.jetbrains.kotlin.codegen.coroutines

import org.jetbrains.kotlin.codegen.optimization.common.CustomFramesMethodAnalyzer
import org.jetbrains.kotlin.codegen.optimization.common.OptimizationBasicInterpreter
import org.jetbrains.kotlin.codegen.optimization.common.StrictBasicValue
import org.jetbrains.kotlin.codegen.optimization.common.insnListOf
import org.jetbrains.org.objectweb.asm.Opcodes
import org.jetbrains.org.objectweb.asm.Type
import org.jetbrains.org.objectweb.asm.tree.*
import org.jetbrains.org.objectweb.asm.tree.analysis.BasicValue
import org.jetbrains.org.objectweb.asm.tree.analysis.Frame
import org.jetbrains.org.objectweb.asm.tree.analysis.Interpreter

/**
 * Moves NEW/DUP of a constructor call to just before the `<init>` call when the uninitialized object
 * produced by NEW is stored to / loaded from local variables in between.
 *
 * In cases like:
 * ```
 * NEW
 * DUP
 * LDC "First"
 * ASTORE 1
 * ASTORE 2
 * ASTORE 3
 * INVOKE suspensionPoint
 * ALOAD 3
 * ALOAD 2
 * ALOAD 1
 * LDC "Second"
 * INVOKESPECIAL <init>(String;String)
 * ```
 * the stores/loads of the uninitialized object are removed and NEW/DUP are moved after the suspension point
 * (slot numbers below assume `maxLocals` was 4 before the transformation):
 * ```
 * LDC "First"
 * ASTORE 1
 * INVOKE suspensionPoint
 * ALOAD 1
 * LDC "Second"
 * ASTORE 4
 * ASTORE 5
 * NEW
 * DUP
 * ALOAD 5
 * ALOAD 4
 * INVOKESPECIAL <init>(String;String)
 * ```
 *
 * This is needed because these variables, which contain an uninitialized object, are later spilled into fields,
 * which leads to a VerifyError.
 * Note that this transformation changes semantics a bit (class <clinit> may be invoked by NEW instruction).
 * TODO: current implementation affects all store/loads of uninitialized objects, even valid ones:
 * MyClass(try { 1 } catch (e: Exception) { 0 }) // here uninitialized MyClass-object is being spilled before try-catch and then loaded
 *
 * How this works:
 * 1. For each invokespecial <init>, determine whether the uninitialized value produced by NEW had any copy usage
 *    besides its single DUP (e.g. it was saved to a local at least once)
 * 2. If it had not, do nothing
 * 3. If it had:
 *   - remove the NEW instruction and all of its copy usages (DUP/LOAD/STORE)
 *   - spill the constructor arguments to new local vars
 *   - generate NEW/DUP
 *   - restore the constructor arguments
 *
 * @param methodNode method whose instruction list is rewritten in place; its `maxLocals` is raised to cover the
 * locals allocated for spilled constructor arguments.
 */
internal fun processUninitializedStores(methodNode: MethodNode) {
    val interpreter = UninitializedNewValueMarkerInterpreter()
    val frames = CustomFramesMethodAnalyzer("fake", methodNode, interpreter, ::UninitializedNewValueFrame).analyze()

    // Iterate over a snapshot: the instruction list is mutated below, while `frames` stays indexed by the original order
    for ((index, insn) in methodNode.instructions.toArray().withIndex()) {
        // Instructions without a frame (e.g. unreachable code) are skipped
        val frame = frames[index] ?: continue
        val uninitializedValue = frame.getUninitializedValueForConstructorCall(insn) ?: continue

        // Every NEW is registered by the interpreter's newOperation, so the lookup is expected to succeed
        val copyUsages: Set<AbstractInsnNode> = interpreter.uninitializedValuesToCopyUsages.getValue(uninitializedValue.newInsn)
        assert(copyUsages.isNotEmpty()) { "At least DUP copy operation expected" }

        // Value generated by NEW wasn't stored to a local (only DUPed): nothing to rewrite
        if (copyUsages.size == 1) continue

        (copyUsages + uninitializedValue.newInsn).forEach {
            methodNode.instructions.remove(it)
        }

        val argumentsCount = Type.getArgumentTypes((insn as MethodInsnNode).desc).size

        // Spill constructor arguments, starting from the top of the stack, into fresh locals past maxLocals
        val spilledArguments = arrayListOf<Pair<Type, Int>>()
        var nextVarIndex = methodNode.maxLocals
        for (i in 0 until argumentsCount) {
            val type = frame.getStack(frame.stackSize - 1 - i).type
            methodNode.instructions.insertBefore(insn, VarInsnNode(type.getOpcode(Opcodes.ISTORE), nextVarIndex))
            spilledArguments.add(type to nextVarIndex)
            nextVarIndex += type.size
        }
        // Slots were allocated upwards from the old maxLocals, so nextVarIndex is the new bound
        methodNode.maxLocals = nextVarIndex

        methodNode.instructions.insertBefore(insn, insnListOf(
                TypeInsnNode(Opcodes.NEW, uninitializedValue.newInsn.desc),
                InsnNode(Opcodes.DUP)
        ))

        // Reload in reverse spill order so the arguments end up on the stack in their original order
        for ((type, varIndex) in spilledArguments.reversed()) {
            methodNode.instructions.insertBefore(insn, VarInsnNode(type.getOpcode(Opcodes.ILOAD), varIndex))
        }
    }
}

/**
 * Analyzer value standing for the object pushed by a NEW instruction whose constructor has not been called yet.
 *
 * @property newInsn the NEW instruction that produced this value
 * @property internalName internal name of the instantiated class; it also determines this value's [type]
 */
private class UninitializedNewValue(
        val newInsn: TypeInsnNode,
        val internalName: String
) : StrictBasicValue(Type.getObjectType(internalName)) {
    override fun toString(): String = "UninitializedNewValue(internalName='$internalName')"
}

/**
 * Frame that turns the NEW-produced value into an initialized one once its constructor has been invoked.
 *
 * Only the stack slot left below the consumed receiver is replaced; other copies (e.g. in locals) are untouched.
 */
private class UninitializedNewValueFrame(nLocals: Int, nStack: Int) : Frame<BasicValue>(nLocals, nStack) {
    override fun execute(insn: AbstractInsnNode, interpreter: Interpreter<BasicValue>?) {
        // Must be inspected before execution: executing <init> consumes the receiver and its arguments
        val isConstructorCallOnNewValue = getUninitializedValueForConstructorCall(insn) != null

        super.execute(insn, interpreter)

        if (!isConstructorCallOnNewValue) return

        // The DUP copy was consumed as receiver; the original NEW result now sits on top and is initialized
        val initializedType = (pop() as UninitializedNewValue).type
        push(StrictBasicValue(initializedType))
    }
}

/**
 * Finds the value produced by NEW that a constructor call initializes.
 *
 * The stack is expected to end with `NEW result, DUP copy (receiver), arg1, ..., argN`.
 *
 * @param insn instruction about to be executed with this frame
 * @return the NEW-produced value lying just below the receiver, or null if [insn] is not a constructor call
 * @throws IllegalStateException if the value below the receiver was not produced by NEW
 */
private fun Frame<BasicValue>.getUninitializedValueForConstructorCall(
        insn: AbstractInsnNode
): UninitializedNewValue? {
    if (!insn.isConstructorCall()) return null

    assert(insn.opcode == Opcodes.INVOKESPECIAL) { "Expected opcode Opcodes.INVOKESPECIAL for <init>, but ${insn.opcode} found" }
    val argumentsCount = Type.getArgumentTypes((insn as MethodInsnNode).desc).size
    // Frame stack entries are one per value, so indices can be computed from the argument count alone
    val receiverIndex = stackSize - 1 - argumentsCount
    val newValueIndex = receiverIndex - 1

    val newValue = getStack(newValueIndex) as? UninitializedNewValue ?: error("Expected value generated with NEW")
    assert(getStack(receiverIndex) is UninitializedNewValue) {
        "Next value after NEW should be one generated by DUP"
    }
    return newValue
}

/** True when this instruction is a method call whose target is named `<init>`. */
private fun AbstractInsnNode.isConstructorCall(): Boolean = (this as? MethodInsnNode)?.name == "<init>"

/**
 * Interpreter that marks every value produced by a NEW instruction as an [UninitializedNewValue] and records,
 * per NEW instruction, each instruction for which the analyzer reported a copy operation on that value.
 */
private class UninitializedNewValueMarkerInterpreter : OptimizationBasicInterpreter() {
    /** NEW instruction -> instructions that copied the value it produced (its DUP, plus any loads/stores). */
    val uninitializedValuesToCopyUsages = hashMapOf<AbstractInsnNode, MutableSet<AbstractInsnNode>>()

    override fun newOperation(insn: AbstractInsnNode): BasicValue? {
        if (insn.opcode != Opcodes.NEW) return super.newOperation(insn)

        // Register the NEW even if its value is never copied; getOrPut keeps usages recorded on earlier visits
        uninitializedValuesToCopyUsages.getOrPut(insn) { mutableSetOf() }
        val typeInsn = insn as TypeInsnNode
        return UninitializedNewValue(typeInsn, typeInsn.desc)
    }

    override fun copyOperation(insn: AbstractInsnNode, value: BasicValue?): BasicValue? {
        if (value !is UninitializedNewValue) return super.copyOperation(insn, value)

        uninitializedValuesToCopyUsages[value.newInsn]!!.add(insn)
        return value
    }

    override fun merge(v: BasicValue, w: BasicValue): BasicValue = when {
        v === w -> v
        v === StrictBasicValue.UNINITIALIZED_VALUE || w === StrictBasicValue.UNINITIALIZED_VALUE ->
            StrictBasicValue.UNINITIALIZED_VALUE
        v is UninitializedNewValue || w is UninitializedNewValue -> {
            val sameNewInsn = (v as? UninitializedNewValue)?.newInsn === (w as? UninitializedNewValue)?.newInsn
            // Merging results of different NEW instructions (or a NEW result with any other value) is possible,
            // but the merged value must not be used afterwards
            if (sameNewInsn) v else StrictBasicValue.UNINITIALIZED_VALUE
        }
        else -> super.merge(v, w)
    }
}
